"""
© 2021 This work is licensed under a CC-BY-NC-SA license.
Title: *"Behavioral cloning in recurrent spiking networks: A comprehensive framework"*
**Authors:** Anonymous
"""

import gym
import numpy as np
from numpy import savetxt
import gbc
import matplotlib.pyplot as plt

# Model dimensions, also bundled into `shape` for the parameter dictionary.
# Presumably: N recurrent units, I inputs (BipedalWalker observation size),
# O outputs (action size), T time steps -- TODO confirm against gbc.GBC.
N = 500
I = 24
O = 4
T = 100
shape = (N, I, O, T)

# Time step and the time constants, each expressed as a multiple of dt.
dt = 0.001
tau_m = 4.0 * dt
tau_s = 2.0 * dt
tau_ro = 3.0 * dt

# exp(-dt / tau) for tau_s and tau_ro.
beta_s = np.exp(-dt / tau_s)
beta_ro = np.exp(-dt / tau_ro)

# Remaining hyperparameters. Their interpretation is defined by gbc.GBC,
# which is not visible from this file.
sigma_teach = 5.0
sigma_input = 50.0
offT = 1
dv = 1 / 5.0
alpha = 0.5
alpha_rout = 0.0005
alpha_pg = 0.0005
Vo = -4
h = -4
s_inh = 20

# Simulation parameters handed to the model constructor below.
par = {
    'tau_m': tau_m,
    'tau_s': tau_s,
    'tau_ro': tau_ro,
    'beta_ro': beta_ro,
    'dv': dv,
    'alpha': alpha,
    'Vo': Vo,
    'h': h,
    's_inh': s_inh,
    'N': N,
    'T': T,
    'dt': dt,
    'offT': offT,
    'alpha_rout': alpha_rout,
    'sigma_input': sigma_input,
    'sigma_teach': sigma_teach,
    'shape': shape,
}

# Build the model. NOTE: this rebinds the name `gbc` from the imported module
# to the GBC instance, so the module is no longer reachable as `gbc` afterwards.
gbc = gbc.GBC(par)

def test(n_episodes, gbc, *, env_name="BipedalWalker-v3", max_timesteps=1500,
         render=False, save_gif=False):
    """Run the model in a gym environment and return the mean episode reward.

    Uses the classic gym API: ``env.reset()`` returns the observation and
    ``env.step()`` returns a ``(state, reward, done, info)`` 4-tuple.

    Args:
        n_episodes: Number of episodes to run.
        gbc: Model exposing ``reset()`` and ``step(state, 1)``. ``step`` must
            return a 4-tuple whose first item is the action passed to
            ``env.step``.
        env_name: Environment id handed to ``gym.make``.
        max_timesteps: Upper bound on the number of steps per episode.
        render: If true, call ``env.render()`` after every step.
        save_gif: If true, write the frame rendered after every step to
            ``./gif/<t>.jpg``. The file name only contains the step index,
            so later episodes overwrite the frames of earlier ones.

    Returns:
        ``np.mean`` of the per-episode total rewards (``nan`` if no episode
        is run).
    """
    env = gym.make(env_name)

    if save_gif:
        import os

        # Make sure the output folder exists before the first frame is saved.
        os.makedirs('./gif', exist_ok=True)

    # The model is reset once up front and again at the start of each episode.
    gbc.reset()

    reward_collection = []

    try:
        for _ in range(n_episodes):
            ep_reward = 0
            state = env.reset()
            gbc.reset()

            for t in range(max_timesteps):
                # Only the action is needed; the remaining three values
                # returned by the model are ignored.
                # NOTE(review): meaning of the second argument (1) is defined
                # by gbc.step and is not visible here.
                action, _, _, _ = gbc.step(state, 1)
                state, reward, done, _ = env.step(action)
                ep_reward += reward

                if render:
                    env.render()
                if save_gif:
                    frame = env.render(mode='rgb_array')
                    plt.imsave('./gif/{}.jpg'.format(t), frame)
                if done:
                    break

            reward_collection.append(ep_reward)
    finally:
        # Close the environment exactly once, after all episodes have run
        # (or an error occurred), never while it is still needed.
        env.close()

    return np.mean(reward_collection)


if __name__ == '__main__':
    # Script entry point: run 10 episodes with the module-level model and
    # print the mean reward returned by test().
    reward = test(n_episodes=10, gbc=gbc)
    print(reward)
